{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Set CUDA_VISIBLE_DEVICES to an empty string before any deep-learning library\n",
    "# is imported; an empty value exposes no CUDA devices to this process.\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES = '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-tiny-standard-bahasa-cased/checkpoint-550000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-560000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-570000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-580000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-590000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-600000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-610000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-620000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-630000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-640000']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "\n",
    "def checkpoint_step(path):\n",
    "    # Sort key that orders checkpoint paths by their numeric step suffix rather\n",
    "    # than lexicographically, so 'checkpoint-1000000' ranks after\n",
    "    # 'checkpoint-990000'. Paths without a numeric suffix sort first, which keeps\n",
    "    # checkpoints[-1] pointing at the highest numbered step.\n",
    "    suffix = path.rsplit('-', 1)[-1]\n",
    "    return (1, int(suffix), path) if suffix.isdecimal() else (0, 0, path)\n",
    "\n",
    "\n",
    "# Every path matching the checkpoint pattern, lowest step first.\n",
    "checkpoints = sorted(glob('finetune-t5-tiny-standard-bahasa-cased/checkpoint-*'), key = checkpoint_step)\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# The tokenizer is loaded from the published 'small' repository, while the model\n",
    "# weights come from the last entry of `checkpoints`.\n",
    "# NOTE(review): presumably the tiny and small models share one vocabulary -- confirm.\n",
    "tokenizer_name = 'mesolitica/t5-small-standard-bahasa-cased'\n",
    "latest_checkpoint = checkpoints[-1]\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)\n",
    "model = T5ForConditionalGeneration.from_pretrained(latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<pad> Hai lelaki! Saya perhatikan semalam & hari ini banyak yang dapat cookies kan. Jadi hari ini saya ingin berkongsi beberapa post mortem kumpulan pertama kami:</s>\n"
     ]
    }
   ],
   "source": [
    "# Smoke test: translate one code-switched sentence and print the raw decoded\n",
    "# output (special tokens are left in on purpose, decode() is called without\n",
    "# skip_special_tokens).\n",
    "sample_text = 'terjemah Inggeris ke Melayu: Hi guys! I noticed semalam & harini dah ramai yang dapat cookies ni kan. So harini i nak share some post mortem of our first batch:'\n",
    "\n",
    "input_ids = tokenizer.encode(sample_text, return_tensors = 'pt')\n",
    "outputs = model.generate(input_ids, max_length = 100)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the loaded model weights to the Hugging Face Hub.\n",
    "# NOTE(review): newer transformers releases appear to deprecate `organization`\n",
    "# in favour of an 'org/name' repo id -- confirm against the installed version.\n",
    "model.push_to_hub(\n",
    "    'finetune-translation-t5-tiny-standard-bahasa-cased',\n",
    "    organization = 'mesolitica',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the tokenizer to the same Hub repository name used for the model.\n",
    "# NOTE(review): newer transformers releases appear to deprecate `organization`\n",
    "# in favour of an 'org/name' repo id -- confirm against the installed version.\n",
    "tokenizer.push_to_hub(\n",
    "    'finetune-translation-t5-tiny-standard-bahasa-cased',\n",
    "    organization = 'mesolitica',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sacrebleu.metrics import BLEU, CHRF, TER\n",
    "\n",
    "# Corpus-level metrics shared by both translation directions.\n",
    "# word_order = 2 adds word bigrams to chrF (reported below as 'chrF2++').\n",
    "bleu, chrf = BLEU(), CHRF(word_order = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unidecode import unidecode\n",
    "\n",
    "# Read the parallel dev files explicitly as UTF-8 instead of relying on the\n",
    "# platform default encoding (assumes the files are UTF-8 -- TODO confirm), and\n",
    "# keep every line, including a final line that has no trailing newline.\n",
    "with open('eng_Latn.dev', encoding = 'utf-8') as fopen:\n",
    "    eng = [line.rstrip('\\n') for line in fopen]\n",
    "\n",
    "with open('zsm_Latn.dev', encoding = 'utf-8') as fopen:\n",
    "    ms = [line.rstrip('\\n') for line in fopen]\n",
    "\n",
    "# Inputs and references are paired by index, so the two files must line up.\n",
    "assert len(eng) == len(ms), f'line count mismatch: {len(eng)} English vs {len(ms)} Malay'\n",
    "\n",
    "# English -> Malay: `left` holds the model inputs, `right` the references.\n",
    "# Both sides are passed through unidecode before use.\n",
    "right = [unidecode(s) for s in ms]\n",
    "left = [unidecode(s) for s in eng]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 997/997 [02:06<00:00,  7.85it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "batch_size = 1\n",
    "\n",
    "# Translate every sentence in `left` (English -> Malay), batch by batch, and\n",
    "# collect the decoded hypotheses in input order.\n",
    "results = []\n",
    "for start in tqdm(range(0, len(left), batch_size)):\n",
    "    batch = left[start:start + batch_size]\n",
    "    # One feature dict per sentence; [0] drops the leading batch axis that\n",
    "    # return_tensors = 'pt' adds, so tokenizer.pad can build the batch itself.\n",
    "    features = [\n",
    "        {'input_ids': tokenizer.encode(f'terjemah Inggeris ke Melayu: {s}', return_tensors = 'pt')[0]}\n",
    "        for s in batch\n",
    "    ]\n",
    "    padded = tokenizer.pad(features, padding = 'longest')\n",
    "    outputs = model.generate(**padded, max_length = 1000)\n",
    "    results.extend(tokenizer.decode(o, skip_special_tokens=True) for o in outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the (hypothesis, reference) pairs whose hypothesis is non-empty;\n",
    "# hypotheses and references are matched by position in `results` / `right`.\n",
    "# NOTE(review): skipping empty outputs removes them from the score instead of\n",
    "# penalising them -- confirm this is intended.\n",
    "kept_pairs = [(hyp, right[idx]) for idx, hyp in enumerate(results) if len(hyp)]\n",
    "filtered_left = [hyp for hyp, _ref in kept_pairs]\n",
    "filtered_right = [ref for _hyp, ref in kept_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `sys` is the list of hypotheses; `refs` wraps the single reference list in\n",
    "# an outer list before it is handed to the scorers.\n",
    "# NOTE(review): the name `sys` would shadow the stdlib module if it were imported.\n",
    "sys, refs = filtered_left, [filtered_right]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corpus-level BLEU for English -> Malay.\n",
    "r = bleu.corpus_score(hypotheses = sys, references = refs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'BLEU',\n",
       " 'score': 41.625536185056305,\n",
       " '_mean': -1.0,\n",
       " '_ci': -1.0,\n",
       " '_verbose': '73.4/50.1/35.7/25.7 (BP = 0.971 ratio = 0.972 hyp_len = 21400 ref_len = 22027)',\n",
       " 'bp': 0.9711259908305946,\n",
       " 'counts': [15718, 10223, 6926, 4731],\n",
       " 'totals': [21400, 20403, 19406, 18409],\n",
       " 'sys_len': 21400,\n",
       " 'ref_len': 22027,\n",
       " 'precisions': [73.44859813084112,\n",
       "  50.10537666029506,\n",
       "  35.68999278573637,\n",
       "  25.699386169808246],\n",
       " 'prec_str': '73.4/50.1/35.7/25.7',\n",
       " 'ratio': 0.9715349343986925}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show every field stored on the BLEU score object.\n",
    "vars(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "chrF2++ = 65.70"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Corpus-level chrF for English -> Malay, using the `chrf` scorer built earlier.\n",
    "chrf.corpus_score(hypotheses = sys, references = refs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the parallel dev files for the reverse direction. Read them explicitly\n",
    "# as UTF-8 instead of relying on the platform default encoding (assumes the files\n",
    "# are UTF-8 -- TODO confirm), and keep every line, including a final line that\n",
    "# has no trailing newline.\n",
    "with open('eng_Latn.dev', encoding = 'utf-8') as fopen:\n",
    "    eng = [line.rstrip('\\n') for line in fopen]\n",
    "\n",
    "with open('zsm_Latn.dev', encoding = 'utf-8') as fopen:\n",
    "    ms = [line.rstrip('\\n') for line in fopen]\n",
    "\n",
    "# Inputs and references are paired by index, so the two files must line up.\n",
    "assert len(eng) == len(ms), f'line count mismatch: {len(eng)} English vs {len(ms)} Malay'\n",
    "\n",
    "# Malay -> English: `left` holds the model inputs, `right` the references.\n",
    "# Both sides are passed through unidecode before use.\n",
    "left = [unidecode(s) for s in ms]\n",
    "right = [unidecode(s) for s in eng]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 997/997 [02:20<00:00,  7.10it/s]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 1\n",
    "\n",
    "# Translate every sentence in `left` (Malay -> English), batch by batch, and\n",
    "# collect the decoded hypotheses in input order.\n",
    "results = []\n",
    "for start in tqdm(range(0, len(left), batch_size)):\n",
    "    batch = left[start:start + batch_size]\n",
    "    # One feature dict per sentence; [0] drops the leading batch axis that\n",
    "    # return_tensors = 'pt' adds, so tokenizer.pad can build the batch itself.\n",
    "    features = [\n",
    "        {'input_ids': tokenizer.encode(f'terjemah Melayu ke Inggeris: {s}', return_tensors = 'pt')[0]}\n",
    "        for s in batch\n",
    "    ]\n",
    "    padded = tokenizer.pad(features, padding = 'longest')\n",
    "    outputs = model.generate(**padded, max_length = 1000)\n",
    "    results.extend(tokenizer.decode(o, skip_special_tokens=True) for o in outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the (hypothesis, reference) pairs whose hypothesis is non-empty;\n",
    "# hypotheses and references are matched by position in `results` / `right`.\n",
    "# NOTE(review): skipping empty outputs removes them from the score instead of\n",
    "# penalising them -- confirm this is intended.\n",
    "kept_pairs = [(hyp, right[idx]) for idx, hyp in enumerate(results) if len(hyp)]\n",
    "filtered_left = [hyp for hyp, _ref in kept_pairs]\n",
    "filtered_right = [ref for _hyp, ref in kept_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `sys` is the list of hypotheses; `refs` wraps the single reference list in\n",
    "# an outer list before it is handed to the scorers.\n",
    "# NOTE(review): the name `sys` would shadow the stdlib module if it were imported.\n",
    "sys, refs = filtered_left, [filtered_right]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'BLEU',\n",
       " 'score': 37.26048464066508,\n",
       " '_mean': -1.0,\n",
       " '_ci': -1.0,\n",
       " '_verbose': '68.3/44.1/30.5/21.4 (BP = 0.995 ratio = 0.995 hyp_len = 23457 ref_len = 23570)',\n",
       " 'bp': 0.9951942593830536,\n",
       " 'counts': [16020, 9908, 6547, 4376],\n",
       " 'totals': [23457, 22460, 21463, 20466],\n",
       " 'sys_len': 23457,\n",
       " 'ref_len': 23570,\n",
       " 'precisions': [68.29517841156158,\n",
       "  44.1139804096171,\n",
       "  30.503657457019056,\n",
       "  21.381803967555946],\n",
       " 'prec_str': '68.3/44.1/30.5/21.4',\n",
       " 'ratio': 0.9952057700466695}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Corpus-level BLEU for Malay -> English, followed by every field stored on\n",
    "# the resulting score object.\n",
    "r = bleu.corpus_score(hypotheses = sys, references = refs)\n",
    "vars(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "chrF2++ = 61.29"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Corpus-level chrF for Malay -> English, using the `chrf` scorer built earlier.\n",
    "chrf.corpus_score(hypotheses = sys, references = refs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
